library(tidyverse)
# remotes::install_github("nickbrazeau/discent", ref = "develop")
library(discent)
source("R/polysim_wrappers.R")
source("R/discent_wrappers.R")
source("R/utils.R")

Overview

Model conceptualization: \[ r_{ij} = \left(\frac{f_i + f_j}{2}\right) \exp\left\{ \frac{-d_{ij}}{M} \right\} \]

Coding Decisions: 1. M and F parameters have different learning rates (uncoupled) 2. M must be positive 3. ADAM gradient descent optimizer

Goals

  1. Model Stability
  2. Model accuracy/realism

Simulated Data

Using polySimIBD, we simulated data under five frameworks, and for each framework we ran 100 independent realizations (i.e. same migration matrix and same Ne, but a different seed). For each framework, there are 25 demes laid out on a 5x5 square plane. Frameworks are as follows:

  1. Isolation by Distance, Ne Constant (migrations are \(\frac{1}{dist}\))
  2. Lattice (can only move to neighbors, no diagonal moves; bounded by borders), Ne Constant
  3. Torus (as above, but no longitudinal boundaries = cylinder), Ne Constant
  4. IBDist, Ne: 5,25,50,100,250 across rows of 5x5 square matrix
  5. Isolation by Distance with wrong distances (to analyze cost)

Model Stability

The goal is to find the best start parameters for each framework. We selected the first realization and made the assumption that start parameters would be consistent across realizations within each framework. However, this also provides the opportunity to determine how much start parameters affect final estimates of DISCs (called final_fis in code) and final global migration rates (called final_ms). Our search grid looks over the following 11220 combinations:

# lookup tables: candidate start values and learning rates for F and M
fstartsvec <- seq(0.1, 0.9, by = 0.05)
mstartsvec <- 10^-(1:6)
flearnsvec <- 10^-(1:10)
mlearnsvec <- 10^-(5:15)
# full factorial design: one row per combination of the four vectors above
search_grid <- stats::setNames(
  expand.grid(fstartsvec, mstartsvec, flearnsvec, mlearnsvec),
  c("fstart", "mstart", "f_learn", "m_learn")
)
#............................................................
# Part 1: Model Stability
#...........................................................
# simulation inputs: deme coordinates and the migration-matrix framework
coordgrid <- readRDS(file = "simdata/simulation_setup/inputs/squarecoords.rds")
migmatdf <- readRDS(file = "simdata/simulation_setup/inputs/migmat_framework.RDS")

# previously saved results of the start-parameter search grid
search_grid_fullraw <- readRDS(file = "disc_results/search_grid_full_for_discdat.RDS")

Spectrum of Costs

# For every stored fit in `discret`, pull out the final cost, the final F
# estimates (kept as a list-column) and the final M estimate. The three
# extractor helpers are defined outside this chunk.
search_grid_full <- dplyr::mutate(
  search_grid_fullraw,
  finalcost = purrr::map_dbl(discret, extract_final_cost),
  final_fs = purrr::map(discret, get_fs),
  final_ms = purrr::map_dbl(discret, get_ms)
)

# distribution of final costs within each model framework
ggplot(search_grid_full, aes(x = modname, y = finalcost)) +
  geom_boxplot() +
  ggtitle("All Costs over Model Frameworks") +
  theme_linedraw()

We will downsample to reasonable costs.

# Down-sample to fits with a reasonable final cost. filter() also drops rows
# whose cost is NA.
search_grid_full <- search_grid_full %>%
  dplyr::filter(finalcost < 1e5)

# Final cost against the final F estimates. `final_fs` is a list-column
# (built with purrr::map), so it has to be unnested before plotting.
search_grid_full %>%
  tidyr::unnest(cols = "final_fs") %>%
  ggplot() +
  # `rep` is also the name of a base function; the .data pronoun guarantees
  # that the data column is what gets mapped to colour
  geom_point(aes(x = finalcost, y = Final_Fis,
                 color = .data$rep)) +
  theme_linedraw() +
  scale_color_viridis_c() +
  facet_wrap(~modname) +
  ggtitle("Costs vs Final Fs") +
  theme(legend.position = "none",
        axis.text.x = element_text(angle = 90))

# Final cost against the final M estimate. `final_ms` is a plain double
# column (built with purrr::map_dbl), so there is nothing to unnest.
search_grid_full %>%
  ggplot() +
  geom_point(aes(x = finalcost, y = final_ms,
                 color = .data$rep)) +
  theme_linedraw() +
  scale_color_viridis_c() +
  facet_wrap(~modname) +
  ggtitle("Costs vs Final Ms") +
  theme(legend.position = "none",
        axis.text.x = element_text(angle = 90))

# Final F estimates against the final M estimate, coloured by final cost.
# Only the list-column `final_fs` is unnested; `final_ms` is repeated for
# every unnested row of the same fit.
search_grid_full %>%
  tidyr::unnest(cols = "final_fs") %>%
  ggplot() +
  geom_point(aes(x = Final_Fis, y = final_ms,
                 color = finalcost)) +
  theme_linedraw() +
  scale_color_viridis_c() +
  facet_wrap(~modname) +
  ggtitle("Final Fs vs Final Ms") +
  theme(axis.text.x = element_text(angle = 90))

Model Accuracy

Note, here we are running the same start parameters identified from the search grid above for our 100 realizations for each model framework. Therefore, there are 5x100 final global M estimations and 5x25x100 final Fi estimations.

# purely for memory: drop the grid-search objects before loading more results
rm(search_grid_fullraw, search_grid_full)
gc()

#............................................................
# Part 2: Model Accuracy
#...........................................................
discretsraw <- readRDS(file = "disc_results/final_discresults_for_discdat.RDS")

# For every stored fit in `discret`, pull out the final cost, the final F
# estimates (kept as a list-column) and the final M estimate.
discrets <- dplyr::mutate(
  discretsraw,
  finalcost = purrr::map_dbl(discret, extract_final_cost),
  final_fs = purrr::map(discret, get_fs),
  final_ms = purrr::map_dbl(discret, get_ms)
)

Consistency of Estimates

# Spread of the final cost within each model framework
discrets %>%
  ggplot() +
  geom_boxplot(aes(x = modname, y = finalcost)) +
  theme_linedraw() +
  ggtitle("Frameworks vs Costs") +
  theme(legend.position = "none",
        axis.text.x = element_text(angle = 90))

# Spread of the final M estimate within each model framework
discrets %>%
  ggplot() +
  geom_boxplot(aes(x = modname, y = final_ms)) +
  theme_linedraw() +
  ggtitle("Frameworks vs Final Ms") +
  theme(legend.position = "none",
        axis.text.x = element_text(angle = 90))

# Final cost against the final M estimate, one panel per framework with free
# axes. No colour aesthetic is mapped, so no colour scale is added.
discrets %>%
  ggplot() +
  geom_point(aes(x = finalcost, y = final_ms), alpha = 0.5) +
  theme_linedraw() +
  facet_grid(rows = vars(modname), scales = "free") +
  ggtitle("Model Frameworks vs Final Ms") +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))

# Final F estimates per deme, coloured by the final cost of the fit they came
# from. `final_fs` is a list-column and is unnested first.
discrets %>%
  tidyr::unnest(cols = "final_fs") %>%
  ggplot() +
  ggbeeswarm::geom_beeswarm(aes(x = Deme, y = Final_Fis, color = finalcost), alpha = 0.5) +
  theme_linedraw() +
  facet_grid(rows = vars(modname)) +
  scale_color_viridis_c() +
  ggtitle("Demes vs Final Fs") +
  theme(axis.text.x = element_text(angle = 90))

Estimates Over Space

# Deme ids are joined as character keys below
# (assumes the `Deme` column inside `final_fs` is character -- TODO confirm)
coordgrid$deme <- as.character(coordgrid$deme)

# Summarise the final F estimates per model framework and deme, then attach
# the deme coordinates. The result is explicitly ungrouped so that no
# leftover `modname` grouping leaks into later steps.
discretsum <- discrets %>%
  tidyr::unnest(cols = "final_fs") %>%
  dplyr::group_by(modname, Deme) %>%
  dplyr::summarise(
    meanFi = mean(Final_Fis),
    varFi = var(Final_Fis),
    minFi = min(Final_Fis),
    maxFi = max(Final_Fis)
  ) %>%
  dplyr::ungroup() %>%
  dplyr::rename(deme = Deme) %>%
  dplyr::left_join(coordgrid, by = "deme")

# Draw one summary statistic on the deme grid, one panel per model framework.
#   dat          data frame with columns longnum, latnum, deme, modname and
#                the column named by `statcol`
#   statcol      name (string) of the column mapped to point colour
#   legend_title title of the colour legend
#   plot_title   title of the plot
# Returns the ggplot object.
plot_deme_stat <- function(dat, statcol, legend_title, plot_title) {
  ggplot(dat, aes(x = longnum, y = latnum)) +
    geom_point(aes(color = .data[[statcol]]), size = 5) +
    # deme label printed in white on top of each point
    geom_text(aes(label = deme), color = "#ffffff", size = 3) +
    scale_color_viridis_c(legend_title) +
    facet_grid(rows = vars(modname)) +
    ylim(c(-2, 24)) +
    ggtitle(plot_title) +
    theme_linedraw() +
    theme(axis.text = element_blank())
}

plot_deme_stat(discretsum, "meanFi", "Mean DISCs", "Mean Fis")

plot_deme_stat(discretsum, "varFi", "var DISCs", "Variance Fis")

plot_deme_stat(discretsum, "minFi", "Min DISCs", "Min Fis")

plot_deme_stat(discretsum, "maxFi", "max DISCs", "max Fis")

Costs Estimates for Fit

Here we will analyze whether wrong geographic distances result in a worse cost.

# keep only the two isolation-by-distance frameworks being compared
discrets_dist_consider <- dplyr::filter(
  discrets,
  modname %in% c("IsoByDist", "badIsoByDist")
)

# final cost for each of the two frameworks
ggplot(discrets_dist_consider, aes(x = modname, y = finalcost)) +
  geom_boxplot() +
  ggtitle("Costs over IBD with Correct and Misspecified Distance") +
  theme_linedraw()